{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Face Mask Detection using PaddlePaddle\n",
    "\n",
    "In this tutorial, we will be using pretrained PaddlePaddle model from [PaddleHub](https://github.com/PaddlePaddle/PaddleHub/tree/release/v1.5/demo/mask_detection/cpp) to do mask detection on the sample image. To complete this procedure, there are two steps needs to be done:\n",
    "\n",
    "- Recognize face on the image (no matter wearing mask or not) using Face object detection model\n",
    "- classify the face is wearing mask or not\n",
    "\n",
    "These two steps will involve two paddle models. We will implement the corresponding preprocess and postprocess logic to it.\n",
    "\n",
    "## Import dependencies and classes\n",
    "\n",
    "PaddlePaddle is one of the Deep Engines that requires DJL hybrid mode to run inference. Itself does not contains NDArray operations and needs a supplemental DL framework to help with that. So we import Pytorch DL engine as well in here to do the processing works."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
    "\n",
    "%maven ai.djl:api:0.12.0\n",
    "%maven ai.djl.paddlepaddle:paddlepaddle-model-zoo:0.12.0\n",
    "%maven ai.djl.paddlepaddle:paddlepaddle-native-auto:2.0.2\n",
    "%maven org.slf4j:slf4j-api:1.7.26\n",
    "%maven org.slf4j:slf4j-simple:1.7.26\n",
    "\n",
    "// second engine to do preprocessing and postprocessing\n",
    "%maven ai.djl.pytorch:pytorch-engine:0.12.0\n",
    "%maven ai.djl.pytorch:pytorch-native-auto:1.8.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.Application;\n",
    "import ai.djl.MalformedModelException;\n",
    "import ai.djl.ModelException;\n",
    "import ai.djl.inference.Predictor;\n",
    "import ai.djl.modality.Classifications;\n",
    "import ai.djl.modality.cv.*;\n",
    "import ai.djl.modality.cv.output.*;\n",
    "import ai.djl.modality.cv.transform.*;\n",
    "import ai.djl.modality.cv.translator.ImageClassificationTranslator;\n",
    "import ai.djl.modality.cv.util.NDImageUtils;\n",
    "import ai.djl.ndarray.*;\n",
    "import ai.djl.ndarray.types.Shape;\n",
    "import ai.djl.repository.zoo.*;\n",
    "import ai.djl.translate.*;\n",
    "\n",
    "import java.io.IOException;\n",
    "import java.nio.file.*;\n",
    "import java.util.*;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Face Detection model\n",
    "\n",
    "Now we can start working on the first model. The model can do face detection and require some additional processing before we feed into it:\n",
    "\n",
    "- Resize: Shrink the image with a certain ratio to feed in\n",
    "- Normalize the image with a scale\n",
    "\n",
    "Fortunatly, DJL offers a `Translator` interface that can help you with these processing. The rough Translator architecture looks like below:\n",
    "\n",
    "![](https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true)\n",
    "\n",
    "In the following sections, we will implement a `FaceTranslator` class to do the work.\n",
    "\n",
    "### Preprocessing\n",
    "\n",
    "In this stage, we will load an image and do some preprocessing work to it. Let's load the image first and take a look at it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "String url = \"https://raw.githubusercontent.com/PaddlePaddle/PaddleHub/release/v1.5/demo/mask_detection/python/images/mask.jpg\";\n",
    "Image img = ImageFactory.getInstance().fromUrl(url);\n",
    "img.getWrappedImage();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, let's try to apply some transformation to it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDList processImageInput(NDManager manager, Image input, float shrink) {\n",
    "            NDArray array = input.toNDArray(manager);\n",
    "            Shape shape = array.getShape();\n",
    "            array = NDImageUtils.resize(\n",
    "                            array, (int) (shape.get(1) * shrink), (int) (shape.get(0) * shrink));\n",
    "            array = array.transpose(2, 0, 1).flip(0); // HWC -> CHW BGR -> RGB\n",
    "            NDArray mean = manager.create(new float[] {104f, 117f, 123f}, new Shape(3, 1, 1));\n",
    "            array = array.sub(mean).mul(0.007843f); // normalization\n",
    "            array = array.expandDims(0); // make batch dimension\n",
    "            return new NDList(array);\n",
    "}\n",
    "\n",
    "processImageInput(NDManager.newBaseManager(), img, 0.5f);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see above, we convert the image to a NDArray with shape following (number_of_batches, channel (RGB), height, width). This is the required input for the model to run object detection.\n",
    "\n",
    "### Postprocessing\n",
    "\n",
    "For postprocessing, The output is in shape of (number_of_boxes, (class_id, probability, xmin, ymin, xmax, ymax)). We can store them into the prebuilt DJL `DetectedObjects` classes for further processing. Let's assume we have an inference output of ((1, 0.99, 0.2, 0.4, 0.5, 0.8)) and try to draw this box out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "DetectedObjects processImageOutput(NDList list, List<String> className, float threshold) {\n",
    "            NDArray result = list.singletonOrThrow();\n",
    "            float[] probabilities = result.get(\":,1\").toFloatArray();\n",
    "            List<String> names = new ArrayList<>();\n",
    "            List<Double> prob = new ArrayList<>();\n",
    "            List<BoundingBox> boxes = new ArrayList<>();\n",
    "            for (int i = 0; i < probabilities.length; i++) {\n",
    "                if (probabilities[i] >= threshold) {\n",
    "                    float[] array = result.get(i).toFloatArray();\n",
    "                    names.add(className.get((int) array[0]));\n",
    "                    prob.add((double) probabilities[i]);\n",
    "                    boxes.add(\n",
    "                            new Rectangle(\n",
    "                                    array[2], array[3], array[4] - array[2], array[5] - array[3]));\n",
    "                }\n",
    "            }\n",
    "            return new DetectedObjects(names, prob, boxes);\n",
    "}\n",
    "\n",
    "NDArray tempOutput = NDManager.newBaseManager().create(new float[]{1f, 0.99f, 0.1f, 0.1f, 0.2f, 0.2f}, new Shape(1, 6));\n",
    "DetectedObjects testBox = processImageOutput(new NDList(tempOutput), Arrays.asList(\"Not Face\", \"Face\"), 0.7f);\n",
    "Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);\n",
    "newImage.drawBoundingBoxes(testBox);\n",
    "newImage.getWrappedImage();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create Translator and run inference\n",
    "\n",
    "After this step, you might understand how process and postprocess works in DJL. Now, let's do something real and put them together in a single piece:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FaceTranslator implements Translator<Image, DetectedObjects> {\n",
    "\n",
    "        private float shrink;\n",
    "        private float threshold;\n",
    "        private List<String> className;\n",
    "\n",
    "        FaceTranslator(float shrink, float threshold) {\n",
    "            this.shrink = shrink;\n",
    "            this.threshold = threshold;\n",
    "            className = Arrays.asList(\"Not Face\", \"Face\");\n",
    "        }\n",
    "\n",
    "        @Override\n",
    "        public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {\n",
    "            return processImageOutput(list, className, threshold);\n",
    "        }\n",
    "\n",
    "        @Override\n",
    "        public NDList processInput(TranslatorContext ctx, Image input) {\n",
    "            return processImageInput(ctx.getNDManager(), input, shrink);\n",
    "        }\n",
    "\n",
    "        @Override\n",
    "        public Batchifier getBatchifier() {\n",
    "            return null;\n",
    "        }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run inference with this model, we need to load the model from Paddle model zoo. To load a model in DJL, you need to specify a `Crieteria`. `Crieteria` is used identify where to load the model and which `Translator` should apply to it. Then, all we need to do is to get a `Predictor` from the model and use it to do inference:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Criteria<Image, DetectedObjects> criteria =\n",
    "        Criteria.builder()\n",
    "                .optApplication(Application.CV.OBJECT_DETECTION)\n",
    "                .setTypes(Image.class, DetectedObjects.class)\n",
    "                .optArtifactId(\"face_detection\")\n",
    "                .optTranslator(new FaceTranslator(0.5f, 0.7f))\n",
    "                .optFilter(\"flavor\", \"server\")\n",
    "                .build();\n",
    "   \n",
    "var model = criteria.loadModel();\n",
    "var predictor = model.newPredictor();\n",
    "\n",
    "DetectedObjects inferenceResult = predictor.predict(img);\n",
    "newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);\n",
    "newImage.drawBoundingBoxes(inferenceResult);\n",
    "newImage.getWrappedImage();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see above, it brings you three faces detections.\n",
    "\n",
    "## Mask Classification model\n",
    "\n",
    "\n",
    "So, once we have the image location ready, we can crop the image and feed it to the Mask Classification model for further processing.\n",
    "\n",
    "### Crop the image\n",
    "\n",
    "The output of the box location is a value from 0 - 1 that can be mapped to the actual box pixel location if we simply multiply by width/height. For better accuracy on the cropped image, we extend the detection box to square. Let's try to get a cropped image:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int[] extendSquare(\n",
    "        double xmin, double ymin, double width, double height, double percentage) {\n",
    "    double centerx = xmin + width / 2;\n",
    "    double centery = ymin + height / 2;\n",
    "    double maxDist = Math.max(width / 2, height / 2) * (1 + percentage);\n",
    "    return new int[] {\n",
    "        (int) (centerx - maxDist), (int) (centery - maxDist), (int) (2 * maxDist)\n",
    "    };\n",
    "}\n",
    "\n",
    "Image getSubImage(Image img, BoundingBox box) {\n",
    "    Rectangle rect = box.getBounds();\n",
    "    int width = img.getWidth();\n",
    "    int height = img.getHeight();\n",
    "    int[] squareBox =\n",
    "            extendSquare(\n",
    "                    rect.getX() * width,\n",
    "                    rect.getY() * height,\n",
    "                    rect.getWidth() * width,\n",
    "                    rect.getHeight() * height,\n",
    "                    0.18);\n",
    "    return img.getSubimage(squareBox[0], squareBox[1], squareBox[2], squareBox[2]);\n",
    "}\n",
    "\n",
    "List<DetectedObjects.DetectedObject> faces = inferenceResult.items();\n",
    "getSubImage(img, faces.get(2).getBoundingBox()).getWrappedImage();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare Translator and load the model\n",
    "\n",
    "For the face classification model, we can use DJL prebuilt `ImageClassificationTranslator` with a few transformation. This Translator brings a basic image translation process and can be extended with additional standard processing steps. So in our case, we don't have to create another `Translator` and just leverage on this prebuilt one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var criteria = Criteria.builder()\n",
    "                        .optApplication(Application.CV.IMAGE_CLASSIFICATION)\n",
    "                        .setTypes(Image.class, Classifications.class)\n",
    "                        .optTranslator(\n",
    "                                ImageClassificationTranslator.builder()\n",
    "                                        .addTransform(new Resize(128, 128))\n",
    "                                        .addTransform(new ToTensor()) // HWC -> CHW div(255)\n",
    "                                        .addTransform(\n",
    "                                                new Normalize(\n",
    "                                                        new float[] {0.5f, 0.5f, 0.5f},\n",
    "                                                        new float[] {1.0f, 1.0f, 1.0f}))\n",
    "                                        .addTransform(nd -> nd.flip(0)) // RGB -> GBR\n",
    "                                        .build())\n",
    "                        .optArtifactId(\"mask_classification\")\n",
    "                        .optFilter(\"flavor\", \"server\")\n",
    "                        .build();\n",
    "\n",
    "var classifyModel = criteria.loadModel();\n",
    "var classifier = classifyModel.newPredictor();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run inference\n",
    "\n",
    "So all we need to do is to apply the previous implemented functions and apply them all together. We firstly crop the image and then use it for inference. After these steps, we create a new DetectedObjects with new Classification classes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "List<String> names = new ArrayList<>();\n",
    "List<Double> prob = new ArrayList<>();\n",
    "List<BoundingBox> rect = new ArrayList<>();\n",
    "for (DetectedObjects.DetectedObject face : faces) {\n",
    "    Image subImg = getSubImage(img, face.getBoundingBox());\n",
    "    Classifications classifications = classifier.predict(subImg);\n",
    "    names.add(classifications.best().getClassName());\n",
    "    prob.add(face.getProbability());\n",
    "    rect.add(face.getBoundingBox());\n",
    "}\n",
    "\n",
    "newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);\n",
    "newImage.drawBoundingBoxes(new DetectedObjects(names, prob, rect));\n",
    "newImage.getWrappedImage();"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "11.0.5+10-LTS"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
